{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 朴素贝叶斯 实现中文文本分类\n",
    "\n",
    "数据说明：文本二分类，1和0，判断是否属于政治上的出访类事件\n",
    "\n",
    "数据来源：[https://github.com/ares5221/ALBERT_text_classification](https://github.com/ares5221/ALBERT_text_classification)\n",
    "\n",
    "目录\n",
    "\n",
    "* [1.加载数据](#1.加载数据)\n",
    "* [2.文本预处理（清洗文本，分词，去除停用词）](#12.文本预处理（清洗文本，分词，去除停用词）)\n",
    "* [3.抽取文本特征：使用词袋模型](#3.抽取文本特征：使用词袋模型)\n",
    "* [4.训练贝叶斯模型（多项式贝叶斯）](#4.训练贝叶斯模型（多项式贝叶斯）)\n",
    "* [5.评价指标&测试](#5.评价指标&测试)\n",
    "* [6.交叉验证](#6.交叉验证)\n",
    "\n",
    "## 1.加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import re\n",
    "import jieba\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "import os\n",
    "# print(os.getcwd())\n",
    "\n",
    "project_dir = '/home/xijian/pycharm_projects/Magic-NLPer/'\n",
    "# 当前目录cwd\n",
    "cwd = project_dir + 'NLP/Classification/binary/naive_bayes/'\n",
    "data_dir = cwd + '../../data/zzcf/' # 数据所在目录\n",
    "stopwords_file = project_dir + 'zh_data/stopwords.txt'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   label                                            content\n",
      "0      1  当地时间2月10日，白宫发表声明称，美国总统特朗普及夫人梅拉尼娅将于2月24日至25日访问印...\n",
      "1      0  俄罗斯卫星通讯社11日最新消息，菲律宾总统杜特尔特已下令终止与美国间的《访问部队协定》(VFA)。\n",
      "2      1  据俄罗斯卫星网6日报道，土耳其总统发言人卡林表示，俄罗斯军事代表团将于近日访问安卡拉，讨论叙...\n",
      "3      0  先来说说什么是LPDDR5：要知道，手机中有两种内存颗粒，一种就是DRAM也就是大家常说的“...\n",
      "4      1  在疫情的关键时刻，出现了一件令人感动的事情，让我们明白这才是真正的好朋友，不惧疫情访问我国，...\n",
      "\n",
      "0    151\n",
      "1    149\n",
      "Name: label, dtype: int64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.\n",
      "  \n",
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:6: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "df_train = pd.read_csv(data_dir+'train.txt', encoding='UTF-8', sep='\\s', header=None,\n",
    "                       names=['label', 'content'], index_col=False)\n",
    "df_train = df_train.dropna() # 过滤含有NaN的数据\n",
    "\n",
    "df_test = pd.read_csv(data_dir+'test.txt', encoding='UTF-8', sep='\\s', header=None,\n",
    "                       names=['label', 'content'], index_col=False)\n",
    "df_test = df_test.dropna() # 过滤含有NaN的数据\n",
    "\n",
    "print(df_train.head())\n",
    "print()\n",
    "print(df_train['label'].value_counts())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(300, 2) (80, 2)\n"
     ]
    }
   ],
   "source": [
    "print(df_train.shape, df_test.shape) # (300, 2), (80, 2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2.文本预处理（清洗文本，分词，去除停用词）"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.991 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "加载停用词...\n",
      "停用词表大小： 2613\n",
      "加载停用词...\n",
      "停用词表大小： 2613\n",
      "   label                                            content  \\\n",
      "0      1  当地时间2月10日，白宫发表声明称，美国总统特朗普及夫人梅拉尼娅将于2月24日至25日访问印...   \n",
      "1      0  俄罗斯卫星通讯社11日最新消息，菲律宾总统杜特尔特已下令终止与美国间的《访问部队协定》(VFA)。   \n",
      "2      1  据俄罗斯卫星网6日报道，土耳其总统发言人卡林表示，俄罗斯军事代表团将于近日访问安卡拉，讨论叙...   \n",
      "3      0  先来说说什么是LPDDR5：要知道，手机中有两种内存颗粒，一种就是DRAM也就是大家常说的“...   \n",
      "4      1  在疫情的关键时刻，出现了一件令人感动的事情，让我们明白这才是真正的好朋友，不惧疫情访问我国，...   \n",
      "\n",
      "                                         content_seg  \n",
      "0  时间 白宫 发表声明 美国 总统 特朗普 夫人 拉尼 日至 访问 印度 特朗普 上任 首次 ...  \n",
      "1  俄罗斯 卫星 通讯社 最新消息 菲律宾 总统 杜特 尔特 下令 终止 美国 访问 部队 协定...  \n",
      "2  俄罗斯 卫星 报道 土耳其 总统 发言人 卡林 俄罗斯 军事 代表团 近日 访问 安卡拉 讨...  \n",
      "3  LPDDR5 手机 中有 两种 内存 颗粒 一种 DRAM 常说 运行 内存 提到 LPDD...  \n",
      "4      疫情 关键时刻 一件 令人感动 事情 明白 这才 朋友 疫情 访问 我国 王毅 机场 迎接  \n",
      "(300, 3)\n"
     ]
    }
   ],
   "source": [
    "# 保留文本中文、数字、英文、短横线\n",
    "def clear_text(text):\n",
    "    p = re.compile(r\"[^\\u4e00-\\u9fa5^0-9^a-z^A-Z\\-、，。！？：；（）《》【】,!\\?:;[\\]()]\")  # 匹配不是中文、数字、字母、短横线的部分字符\n",
    "    return p.sub('', text)  # 将text中匹配到的字符替换成空字符\n",
    "\n",
    "# 加载停用词表\n",
    "def load_stopwords_file(filename):\n",
    "    print('加载停用词...')\n",
    "    stopwords=pd.read_csv(filename, index_col=False, quoting=3, sep=\"\\t\", names=['stopword'], encoding='utf-8')\n",
    "    #quoting：控制引用字符引用行为，QUOTE_MINIMAL (0), QUOTE_ALL (1), QUOTE_NONNUMERIC (2) or QUOTE_NONE (3).\n",
    "    stopwords = set(stopwords['stopword'].values)\n",
    "    print('停用词表大小：', len(stopwords))\n",
    "    return stopwords\n",
    "\n",
    "\n",
    "# 文本预处理：清洗，分词，并去除停用词\n",
    "def preprocess_text(df_content):\n",
    "    stopwords_set = load_stopwords_file(stopwords_file) # 加载停用词表\n",
    "    content_seg = []#分词后的content\n",
    "    for i,text in enumerate(df_content):\n",
    "        text = clear_text(text.strip())\n",
    "        segs = jieba.lcut(text, cut_all=False)  # cut_all=False是精确模式，True是全模式；默认模式是False 返回分词后的列表\n",
    "        segs = filter(lambda x: len(x.strip())>1, segs)  # 词长度要>1，没有保留标点符号\n",
    "        segs = filter(lambda x: x not in stopwords_set, segs)\n",
    "        # print(segs) # segs是一个filter object\n",
    "        # segs = list(segs) # segs需要一次类似“持久化”的操作，否则每次被操作一次后segs就为空了\n",
    "        content_seg.append(\" \".join(segs))\n",
    "    return content_seg\n",
    "\n",
    "df_train['content_seg'] = preprocess_text(df_train['content'])\n",
    "df_test['content_seg'] = preprocess_text(df_test['content'])\n",
    "print(df_train.head())  #\n",
    "print(df_train.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3.抽取文本特征：使用词袋模型"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CountVectorizer train finished!\n"
     ]
    }
   ],
   "source": [
    "vectorizer = CountVectorizer(analyzer='word', # 以词为粒度做ngram\n",
    "                             max_features=4000, # 保留最常见的4000个词\n",
    "                             )\n",
    "vectorizer.fit(df_train['content_seg'].tolist())\n",
    "print('CountVectorizer train finished!')\n",
    "train_features = vectorizer.transform(df_train['content_seg'])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 4.训练贝叶斯模型（多项式贝叶斯）"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MultinomialNB train finished!\n"
     ]
    }
   ],
   "source": [
    "bayes_classifier = MultinomialNB() # 使用多项式贝叶斯\n",
    "bayes_classifier.fit(train_features, df_train['label'])\n",
    "print('MultinomialNB train finished!')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 5.评价指标&测试"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "data": {
      "text/plain": "0.9375"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_features = vectorizer.transform(df_test['content_seg'])\n",
    "bayes_classifier.score(test_features, df_test['label']) # Return the mean accuracy"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision_score: 0.952\n",
      "recall_score: 0.930\n",
      "f1_score: 0.941\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score\n",
    "y_pred = bayes_classifier.predict(test_features)\n",
    "print('precision_score: %.3f' % precision_score(y_pred, df_test['label']))\n",
    "print('recall_score: %.3f' % recall_score(y_pred, df_test['label']))\n",
    "print('f1_score: %.3f' % f1_score(y_pred, df_test['label']))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 6.交叉验证"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(380, 3)\n"
     ]
    }
   ],
   "source": [
    "df_data = df_train.append(df_test)\n",
    "print(df_data.shape)\n",
    "# print(df_data.head())\n",
    "df_data = df_data.sample(frac=1.0) # shuffle 打乱顺序\n",
    "# print(df_data.head())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8689320388349514\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "import numpy as np\n",
    "\n",
    "def stratifiedkfold_cv(X, y, clf_class, shuffle=True, n_folds=5, **kwargs):\n",
    "    skfold = StratifiedKFold(n_splits=n_folds, shuffle=shuffle, random_state=0)\n",
    "    y_pred = y[:]\n",
    "    for train_idx, test_idx in skfold.split(X, y):\n",
    "        X_tr, y_tr = X[train_idx], y[train_idx] # 训练集\n",
    "        X_te, y_te = X[test_idx], y[test_idx] # 测试集\n",
    "        clf = clf_class(**kwargs)\n",
    "        clf.fit(X_tr,y_tr)\n",
    "        y_pred[test_idx] = clf.predict(X_te)\n",
    "    return y_pred\n",
    "\n",
    "\n",
    "y_pred = stratifiedkfold_cv(vectorizer.transform(df_data['content_seg']),\n",
    "                            np.array(df_data['label']),\n",
    "                            MultinomialNB)\n",
    "\n",
    "print(precision_score(df_data['label'], y_pred))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}